# HETA_models.py
# Load pretrained decoder-only LMs configured for HETA:
# - returns (model, tokenizer)
# - model is in eval mode with output_attentions=True
# - tokenizer has a valid pad_token set
# - uses device_map="auto"; dtype defaults to bfloat16 when CUDA supports it, else float16 on CUDA, else float32

from typing import Tuple, Dict
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# ---- Utility -----------------------------------------------------------------

def _std_tokenizer(tok):
    """Normalize a tokenizer in place and return it.

    Guarantees a pad token: an existing one is left untouched; otherwise the
    EOS token is reused, and if there is no EOS token either, a new
    ``[PAD]`` special token is registered. Truncation is always switched to
    the left side.
    """
    if tok.pad_token is None:
        eos = tok.eos_token
        if eos is None:
            # No EOS to borrow: register a dedicated padding token.
            tok.add_special_tokens({'pad_token': '[PAD]'})
        else:
            tok.pad_token = eos

    # Left truncation keeps the most recent context, which is stable for
    # next-token attribution on long inputs.
    tok.truncation_side = "left"
    return tok

def _load_causal_lm(
    model_name: str,
    trust_remote_code: bool = True,
    attn_outputs: bool = True,
    device_map: str = "auto",
    dtype = None,
) -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Load a pretrained causal LM together with its standardized tokenizer.

    Args:
        model_name: Hub id or local path forwarded to ``from_pretrained``.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.
            NOTE(review): the default of True allows code shipped in the
            checkpoint repository to run locally; pass False for
            repositories that are not trusted.
        attn_outputs: When true, the model is configured to return attention
            weights (``config.output_attentions``) and is loaded with the
            eager attention implementation so that weights are available.
        device_map: Forwarded to ``AutoModelForCausalLM.from_pretrained``.
        dtype: Torch dtype for the weights. When None, bfloat16 is used if
            CUDA reports bf16 support, float16 on other CUDA devices, and
            float32 when CUDA is unavailable.

    Returns:
        ``(model, tokenizer)`` with the model in eval mode and the tokenizer
        passed through ``_std_tokenizer``.
    """
    if dtype is None:
        # Prefer bf16 if available; otherwise float16 on CUDA; else float32.
        if torch.cuda.is_available():
            if torch.cuda.is_bf16_supported():
                dtype = torch.bfloat16
            else:
                dtype = torch.float16
        else:
            dtype = torch.float32

    tok = AutoTokenizer.from_pretrained(model_name, trust_remote_code=trust_remote_code, use_fast=True)
    tok = _std_tokenizer(tok)

    load_kwargs = dict(
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=trust_remote_code,
        low_cpu_mem_usage=True,
    )
    if attn_outputs:
        # Fused attention kernels (SDPA / flash) do not materialize attention
        # weights; request the eager implementation so that
        # output_attentions=True actually yields them.
        load_kwargs["attn_implementation"] = "eager"

    model = AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)

    # _std_tokenizer may have added a '[PAD]' token. Grow the embedding matrix
    # only when the tokenizer (including added tokens, hence len(tok) rather
    # than tok.vocab_size) no longer fits. Checkpoints whose embedding matrix
    # is larger than the tokenizer are left alone rather than shrunk.
    if hasattr(model, "resize_token_embeddings"):
        n_embedding_rows = model.get_input_embeddings().weight.size(0)
        if len(tok) > n_embedding_rows:
            model.resize_token_embeddings(len(tok))

    model.config.output_attentions = bool(attn_outputs)
    model.eval()
    return model, tok

# ---- Model-specific loaders (names reflect paper baselines) -------------------

def load_gpt2() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for the ``gpt2`` checkpoint."""
    checkpoint = "gpt2"
    return _load_causal_lm(checkpoint)

def load_gptj_6b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for the ``EleutherAI/gpt-j-6B`` checkpoint."""
    checkpoint = "EleutherAI/gpt-j-6B"
    return _load_causal_lm(checkpoint)

def load_phi3_14b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for Microsoft Phi-3 Medium 14B Instruct.

    The Hugging Face repository name may vary; adjust the id below if needed.
    """
    checkpoint = "microsoft/Phi-3-medium-4k-instruct"
    return _load_causal_lm(checkpoint)

def load_llama31_70b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for ``meta-llama/Llama-3.1-70B``.

    Requires access acceptance on Hugging Face; uses bf16 on A100 80GB
    class GPUs.
    """
    checkpoint = "meta-llama/Llama-3.1-70B"
    return _load_causal_lm(checkpoint)

def load_qwen25_3b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for the ``Qwen/Qwen2.5-3B-Instruct`` checkpoint."""
    checkpoint = "Qwen/Qwen2.5-3B-Instruct"
    return _load_causal_lm(checkpoint)

# (Optional) other paper baselines commonly used in open-source experiments
def load_opt_66b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for the ``facebook/opt-66b`` checkpoint."""
    checkpoint = "facebook/opt-66b"
    return _load_causal_lm(checkpoint)

def load_llama31_8b() -> Tuple[torch.nn.Module, AutoTokenizer]:
    """Return ``(model, tokenizer)`` for the ``meta-llama/Llama-3.1-8B`` checkpoint."""
    checkpoint = "meta-llama/Llama-3.1-8B"
    return _load_causal_lm(checkpoint)

# ---- Bulk loader -------------------------------------------------------------

def load_all(*names: str) -> Dict[str, Tuple[torch.nn.Module, AutoTokenizer]]:
    """Load several baselines and return them keyed by short name.

    Args:
        *names: Short names of the models to load. When omitted, the default
            set is loaded: ``"gpt2"``, ``"gpt-j-6b"``, ``"phi-3-14b"``,
            ``"llama-3.1-70b"`` and ``"qwen2.5-3b"``. The opt-in names
            ``"llama-3.1-8b"`` and ``"opt-66b"`` are only loaded when
            requested explicitly.

    Returns:
        Dict mapping each short name to its ``(model, tokenizer)`` pair, in
        the order the names were given (or the default order).

    Raises:
        KeyError: If any requested name is unknown. Names are validated
            before any model is loaded.
    """
    default_loaders = {
        "gpt2": load_gpt2,
        "gpt-j-6b": load_gptj_6b,
        "phi-3-14b": load_phi3_14b,
        "llama-3.1-70b": load_llama31_70b,
        "qwen2.5-3b": load_qwen25_3b,
    }
    # Not part of the default set; available on request only.
    optional_loaders = {
        "llama-3.1-8b": load_llama31_8b,
        "opt-66b": load_opt_66b,
    }

    if not names:
        selected = default_loaders
    else:
        available = {**default_loaders, **optional_loaders}
        unknown = [n for n in names if n not in available]
        if unknown:
            # Fail before triggering any load.
            raise KeyError(
                f"Unknown model name(s): {unknown}; choose from {sorted(available)}"
            )
        selected = {n: available[n] for n in names}

    # All selected models are held in memory simultaneously.
    return {name: loader() for name, loader in selected.items()}

# ---- Smoke test --------------------------------------------------------------

if __name__ == "__main__":
    # Smoke test: load every default model, then print one summary line per
    # model (vocab size from the config, embedding width, parameter dtype and
    # the tokenizer's pad token).
    for label, pair in load_all().items():
        lm, tokenizer = pair
        vocab = lm.config.vocab_size
        width = lm.get_input_embeddings().weight.shape[-1]
        param_dtype = next(lm.parameters()).dtype
        print(f"[OK] {label:16s} | vocab={vocab} | d={width} | dtype={param_dtype} | pad={tokenizer.pad_token!r}")
